Training a Machine Translation Model¶
Note
This section implements a function for training a machine translation model.
Loss Function¶
At each time step, the decoder predicts a probability distribution over the output tokens. As with language models, we can obtain this distribution via softmax and optimize the model by computing the cross-entropy loss.
However, padding tokens should be excluded from the loss computation.
import collections
import math

import torch
from torch import nn
import d2l
#@save
def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences"""
    # `X` shape: (`batch_size`, `num_steps`)
    # [None, :] turns (`num_steps`,) into (1, `num_steps`)
    # [:, None] turns (`batch_size`,) into (`batch_size`, 1)
    mask = torch.arange((X.size(1)), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X
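As a quick sanity check (the values below are just illustrative), masking a 2x3 matrix with valid lengths 1 and 2 zeros out everything past the first and second position, respectively:
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
sequence_mask(X, torch.tensor([1, 2]))
# tensor([[1, 0, 0],
#         [4, 5, 0]])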
#@save
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """The softmax cross-entropy loss with masks"""
    # shape of `pred`: (`batch_size`, `num_steps`, `vocab_size`)
    # shape of `label`: (`batch_size`, `num_steps`)
    # shape of `valid_len`: (`batch_size`,)
    def forward(self, pred, label, valid_len):
        # 1 for non-padding tokens, 0 for padding tokens
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        # 'none': no reduction is applied, so we can weight each position ourselves
        self.reduction = 'none'
        # nn.CrossEntropyLoss expects inputs of shape (`batch_size`, `vocab_size`, `num_steps`)
        # and targets of shape (`batch_size`, `num_steps`)
        unweighted_loss = super().forward(pred.permute(0, 2, 1), label)
        # Zero out the loss on padding positions and average over time steps
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss
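As a sanity check (uniform dummy predictions with made-up shapes), three identical sequences with valid lengths 4, 2, and 0 yield proportionally smaller losses, since masked positions contribute nothing:
loss = MaskedSoftmaxCELoss()
loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long),
     torch.tensor([4, 2, 0]))
# tensor([2.3026, 1.1513, 0.0000])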
Training¶
During training, the decoder's input is not sampled from its own output at the previous step; instead it is <bos> followed by the ground-truth output sequence. This is known as teacher forcing, and it lets us train faster.
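For concreteness, here is a tiny sketch of how the decoder input is assembled; the token ids are made up:
bos = torch.tensor([[1]])                   # hypothetical <bos> id
Y = torch.tensor([[5, 6, 7, 8]])            # hypothetical ground-truth target ids
dec_input = torch.cat([bos, Y[:, :-1]], 1)  # tensor([[1, 5, 6, 7]])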
#@save
def train_nmt(net, data_iter, lr, num_epochs, tgt_vocab):
    """Train a machine translation model"""
    device = d2l.try_gpu()
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = MaskedSoftmaxCELoss()
    net.train()  # Dropout is used, so training mode must be set explicitly
    # Plot the masked cross-entropy loss
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        # Sum of losses, number of tokens
        metric = d2l.Accumulator(2)
        for batch in data_iter:
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            # The decoder input is <bos> followed by the ground-truth output sequence
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)
            # The model must follow the encoder-decoder architecture
            Y_hat, _ = net(X, dec_input, X_valid_len)
            # Backpropagation
            optimizer.zero_grad()
            loss = loss_fn(Y_hat, Y, Y_valid_len)
            loss.sum().backward()
            optimizer.step()
            # Record metrics
            with torch.no_grad():
                metric.add(loss.sum(), Y_valid_len.sum())
        # Plot
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
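As a usage sketch (not part of this section's saved code): assuming the data loader and the seq2seq encoder-decoder classes built in earlier sections are exposed by the d2l package under the names used below (d2l.load_data_nmt, d2l.Seq2SeqEncoder, d2l.Seq2SeqDecoder, and d2l.EncoderDecoder are assumptions, as are the hyperparameters), training could look like:
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps, lr, num_epochs = 64, 10, 0.005, 300
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens,
                             num_layers, dropout)
decoder = d2l.Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens,
                             num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
train_nmt(net, train_iter, lr, num_epochs, tgt_vocab)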
Prediction¶
At prediction time there is no ground-truth output sequence, so the decoder's input at the current time step comes from the output token of the previous time step.
#@save
def predict_nmt(net, src_sentence, src_vocab, tgt_vocab, num_steps,
                device=d2l.try_gpu()):
    """Predict with a machine translation model"""
    net.eval()
    # Preprocess src_sentence
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    enc_X = torch.unsqueeze(
        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    # Initial decoder state and initial decoder input
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    dec_X = torch.unsqueeze(
        torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
    output_seq = []
    # Decode one time step at a time
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)
        # Use the token with the highest predicted likelihood as the decoder
        # input at the next time step
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        # Once the end-of-sequence token is predicted, the output sequence is complete
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)
    return ' '.join(tgt_vocab.to_tokens(output_seq))
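Once a model has been trained, translating a sentence is a single call; the source sentence and num_steps below are only illustrative, and net, src_vocab, and tgt_vocab are assumed to come from the training sketch above:
predict_nmt(net, 'go .', src_vocab, tgt_vocab, num_steps=10)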
Evaluation¶
We can evaluate a predicted sequence by comparing it with the ground-truth label sequence.
Let \(p_{n}\) denote the precision of \(n\)-grams: the ratio of the number of matched \(n\)-grams between the predicted and label sequences to the number of \(n\)-grams in the predicted sequence.
BLEU is then defined as:
\[\exp\left(\min\left(0, 1 - \frac{\mathrm{len}_{\text{label}}}{\mathrm{len}_{\text{pred}}}\right)\right)\prod_{n=1}^{k}p_{n}^{1/2^{n}}\]
where \(k\) is the longest \(n\)-gram used for matching, and the exponential term penalizes predicted sequences that are too short.
#@save
def bleu(pred_seq, label_seq, k):
    """Compute BLEU"""
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    # Compute the n-gram precisions
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        # Count the n-grams in the label sequence
        # (join with spaces so that adjacent tokens cannot collide)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i:i + n])] += 1
        # Count the matches
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i:i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i:i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score
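As a worked example with made-up sentences: for the prediction a b b c d against the label a b c d e f with k=2, the unigram precision is 4/5 (the second b has no match left), the bigram precision is 3/4, and the brevity penalty is exp(1 - 6/5), so the score is roughly 0.68:
bleu('a b b c d', 'a b c d e f', k=2)
# ~0.68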